{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scaled Dot Product Attention (SDPA) in cuDNN Frontend\n",
    "\n",
    "This notebook demonstrates the scaled dot product attention (SDPA) operator in the cuDNN frontend. The operation computes\n",
    "\n",
    "$\\text{Attention}(Q, K, V) = \\text{softmax}\\left(\\frac{QK^T}{\\sqrt{d}}\\right)V$\n",
    "\n",
    "using the FlashAttention-2 algorithm described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). Here $d$ is the embedding dimension per head. The operation applies to both training and inference, and can optionally generate a stats tensor for use in the backward pass during training.\n",
    "\n",
    "The full documentation can be found in: [docs/operations/Attention.md#scaled-dot-product-attention](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention)\n",
    "\n",
    "The Python test code for the full set of features can be found in: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prerequisites and Setup\n",
    "This notebook requires an NVIDIA A100 GPU or newer. If running on Colab, go to Runtime → Change runtime type → Hardware accelerator and select a GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('nvidia-smi')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_ipython().system('export CUDA_VERSION=\"12.3\"')\n",
    "# get_ipython().system('pip install nvidia-cudnn-cu12')\n",
    "# get_ipython().system('CUDNN_PATH=`pip show nvidia-cudnn-cu12  | grep Location | cut -d\":\" -f2 | xargs`/nvidia/cudnn pip install git+https://github.com/NVIDIA/cudnn-frontend.git')\n",
    "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cudnn\n",
    "import torch\n",
    "import math\n",
    "\n",
    "torch.manual_seed(42)\n",
    "handle = cudnn.create_handle()\n",
    "\n",
    "assert torch.cuda.is_available()\n",
    "assert torch.cuda.get_device_capability()[0] >= 8, \"SDPA operation is only supported on SM80 architecture (Ampere) or above\"\n",
    "assert cudnn.backend_version() >= 8903, \"SDPA operation is only supported on cuDNN version 8.9.3 or above\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Problem sizes\n",
    "\n",
    "For this example, we use the problem size from the original GPT-2 paper, where:\n",
    " - maximum sequence length = 1024\n",
    " - hidden dim = number of heads $\\times$ embedding dimension per head = 12 $\\times$ 64 = 768"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "b = 4    # batch size\n",
    "h = 12   # number of query heads\n",
    "s = 1024 # maximum sequence length\n",
    "d = 64   # embedding dimension per head\n",
    "\n",
    "attn_scale = 1.0 / math.sqrt(d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the query, key, value, and output GPU tensors using PyTorch. Any DLPack-compatible tensor can be used instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The tensors will have non-interleaved\n",
    "# BSHD (batch, sequence_length, num_head, dims_per_head) physical tensor layout\n",
    "# BHSD (batch, num_head, sequence_length, dims_per_head) logical tensor layout\n",
    "dims = (b, h, s, d)\n",
    "strides = (s * h * d, d, h * d, 1)\n",
    "\n",
    "q_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n",
    "k_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n",
    "v_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n",
    "o_gpu = torch.empty(b * s * h * d).half().cuda().as_strided(dims, strides)"
   ]
  },
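  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustrative aside (not part of the original sample), the sketch below shows in plain Python how the stride tuple used above maps a logical BHSD index to an element offset in BSHD-ordered memory. The helper names `bshd_strides` and `flat_offset` and the toy sizes are hypothetical and chosen only for this illustration; the stride expression itself is the one from the cell above.\n",
    "\n",
    "```python\n",
    "def bshd_strides(h, s, d):\n",
    "    # Strides for logical (b, h, s, d) indexing over memory stored in\n",
    "    # (b, s, h, d) order, matching the expression used in this notebook.\n",
    "    return (s * h * d, d, h * d, 1)\n",
    "\n",
    "\n",
    "def flat_offset(index, strides):\n",
    "    # Element offset = sum over axes of index[axis] * strides[axis].\n",
    "    return sum(i * st for i, st in zip(index, strides))\n",
    "\n",
    "\n",
    "# Toy sizes (assumption, for illustration only; not the GPT-2 sizes above).\n",
    "toy_h, toy_s, toy_d = 2, 3, 4\n",
    "toy_strides = bshd_strides(toy_h, toy_s, toy_d)\n",
    "\n",
    "# Visiting one batch entry in physical BSHD order (sequence outermost, then\n",
    "# head, then embedding) should touch offsets 0, 1, 2, ... with no gaps.\n",
    "visited = [\n",
    "    flat_offset((0, hi, si, di), toy_strides)\n",
    "    for si in range(toy_s)\n",
    "    for hi in range(toy_h)\n",
    "    for di in range(toy_d)\n",
    "]\n",
    "assert visited == list(range(toy_s * toy_h * toy_d))\n",
    "```\n",
    "\n",
    "Because the last stride is 1 and the head stride is `d`, the embedding values of one head at one position are contiguous, while a step along the sequence axis skips over all `h * d` values belonging to a token."
   ]
  },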
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = cudnn.pygraph(\n",
    "    io_data_type=cudnn.data_type.HALF,\n",
    "    intermediate_data_type=cudnn.data_type.FLOAT,\n",
    "    compute_data_type=cudnn.data_type.FLOAT,\n",
    ")\n",
    "\n",
    "q = graph.tensor_like(q_gpu)\n",
    "k = graph.tensor_like(k_gpu)\n",
    "v = graph.tensor_like(v_gpu)\n",
    "\n",
    "# The second return value is the stats tensor, which is only used for training.\n",
    "# The causal mask is enabled.\n",
    "o, _ = graph.sdpa(\n",
    "    name=\"sdpa\",\n",
    "    q=q, k=k, v=v,\n",
    "    is_inference=True,\n",
    "    attn_scale=attn_scale,\n",
    "    use_causal_mask=True,\n",
    ")\n",
    "\n",
    "o.set_output(True).set_dim(dims).set_stride(strides)\n",
    "pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph.validate()\n",
    "graph.build_operation_graph()\n",
    "graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n",
    "graph.check_support()\n",
    "graph.build_plans()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Execute the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "variant_pack = {\n",
    "    q: q_gpu,\n",
    "    k: k_gpu,\n",
    "    v: v_gpu,\n",
    "    o: o_gpu,\n",
    "}\n",
    "\n",
    "workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n",
    "graph.execute(variant_pack, workspace)\n",
    "torch.cuda.synchronize()"
   ]
  },
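  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about what the graph computes for each head, the formula at the top of this notebook, combined with the causal mask, can be sketched in plain Python on a toy input. This is a naive reference for illustration only, not the FlashAttention-2 algorithm and not cuDNN's implementation; the function name `toy_causal_sdpa` and the toy values are hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def toy_causal_sdpa(q, k, v, scale):\n",
    "    # q, k, v: lists of rows, one row per sequence position, for a single head.\n",
    "    out = []\n",
    "    for i, q_row in enumerate(q):\n",
    "        # Causal mask: position i only attends to positions 0..i.\n",
    "        scores = [\n",
    "            scale * sum(a * b for a, b in zip(q_row, k[j]))\n",
    "            for j in range(i + 1)\n",
    "        ]\n",
    "        # Numerically stable softmax: subtract the row maximum first.\n",
    "        row_max = max(scores)\n",
    "        weights = [math.exp(x - row_max) for x in scores]\n",
    "        total = sum(weights)\n",
    "        # Weighted average of the visible value rows.\n",
    "        out.append([\n",
    "            sum(w * v[j][c] for j, w in enumerate(weights)) / total\n",
    "            for c in range(len(v[0]))\n",
    "        ])\n",
    "    return out\n",
    "\n",
    "\n",
    "# Toy single-head input: sequence length 2, embedding dimension 2.\n",
    "toy_q = [[1.0, 0.0], [0.0, 1.0]]\n",
    "toy_k = [[1.0, 0.0], [0.0, 1.0]]\n",
    "toy_v = [[1.0, 2.0], [3.0, 4.0]]\n",
    "toy_out = toy_causal_sdpa(toy_q, toy_k, toy_v, 1.0 / math.sqrt(2))\n",
    "\n",
    "# The first position can only see itself, so its output is the first value row.\n",
    "assert toy_out[0] == [1.0, 2.0]\n",
    "```\n",
    "\n",
    "The second output row is a convex combination of both value rows, weighted toward the second one because the second query matches the second key."
   ]
  },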
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare cuDNN's output against PyTorch's reference implementation to check correctness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "q_ref = q_gpu.detach().float().requires_grad_()\n",
    "k_ref = k_gpu.detach().float().requires_grad_()\n",
    "v_ref = v_gpu.detach().float().requires_grad_()\n",
    "\n",
    "o_ref = torch.nn.functional.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=True, scale=attn_scale)\n",
    "torch.testing.assert_close(o_ref, o_gpu.float(), atol=5e-3, rtol=3e-3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
